# ------------------------------------------------------------------------------
# Trains the ensemble model on the validation set.
# From: stress_test_medicare repo <https://gitlab.com/labsysmed/zolab-projects/stress_test_medicare/-/blob/master/code/03_analysis/01_build-models/06_train-ensemble.R>
# Updated by: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bash 06_train_ensemble.sh {outcome} {split} {restriction}
# ------------------------------------------------------------------------------

# Seeding ----------------------------------------------------------------------
# Fix the RNG state so any random draws made by this script repeat run-to-run.
set.seed(seed = 1L)

# Libraries --------------------------------------------------------------------
message("Loading libraries...")
library(yaml)
library(data.table)
library(Matrix)
library(tidyverse)
library(glue)
library(optparse)
library(here) # here() relative filepaths
library(testit) # assert()
library(OneR) # bin()

# Project utilities (this script uses u$safe_left_join()) and this script's
# directory, which holds the per-outcome ensemble configs.
u <- modules::use(here::here("lib", "util.R"))
dir <- here::here("code", "03_analysis", "01_build_models")

# Command Line Args ------------------------------------------------------------
arg_config <- list(
  make_option(
    "--outcome", type = "character",
    help = "Outcome name; selects ensemble_config/<outcome>.yml (required)"
  ),
  make_option(
    "--split", type = "character",
    help = "Data split; selects the subscores/models subdirectory (required)"
  ),
  make_option(
    "--restriction", type = "character",
    help = "Sample restriction; 'all' adds no covariate suffix (required)"
  )
)
arg_parser <- OptionParser(option_list = arg_config)
arg_list <- parse_args(arg_parser)

# Fail fast: an option that was not supplied comes back as NULL, which would
# otherwise surface much later as an empty file.path() or an
# "argument is of length zero" error.
missing_args <- Filter(
  function(arg) is.null(arg_list[[arg]]),
  c("outcome", "split", "restriction")
)
if (length(missing_args) > 0) {
  print_help(arg_parser)
  stop(
    "Missing required option(s): ",
    paste0("--", missing_args, collapse = ", "),
    call. = FALSE
  )
}

# Kept as a global: glue(paths$modeling$val) below interpolates in this
# environment, presumably referencing {split} -- confirm in filepaths.yml.
split <- arg_list$split

# Helpers ----------------------------------------------------------------------
#' Strip per-observation vectors from a fitted glm before it is saved.
#'
#' Removes the `residuals`, `fitted.values`, `weights` and `prior.weights`
#' components; every other component is left untouched.
#' NOTE(review): a glm fit also carries `data`, `linear.predictors` and
#' `effects`, which are not removed here -- confirm whether they should be.
#'
#' @param model A fitted `glm` object.
#' @return `model` without the four components listed above.
trim_glm <- function(model) {
  dropped <- c("residuals", "fitted.values", "weights", "prior.weights")
  for (component in dropped) {
    model[[component]] <- NULL
  }
  model
}

# Directories ------------------------------------------------------------------
message("Establishing directories...")
paths <- read_yaml(here("lib", "filepaths.yml"))
# Subscore inputs and model outputs are both partitioned by split/restriction.
subscores_dir <- file.path(
  paths$modeling$dir, "subscores", arg_list$split, arg_list$restriction
)
models_dir <- file.path(
  paths$modeling$dir, "models", arg_list$split, arg_list$restriction
)
# Create the output directory up front so a missing directory cannot make the
# final write fail only after the model has been fitted.
dir.create(models_dir, recursive = TRUE, showWarnings = FALSE)

# Load Data --------------------------------------------------------------------
message("Loading data...")
config <- read_yaml(
  file.path(dir, "ensemble_config", glue("{arg_list$outcome}.yml"))
)
# Check the config before reading the data files: each of these keys is used
# later in the script, and a missing one otherwise fails late and opaquely.
local({
  required_keys <- c("population", "logit", "constant", "target", "components")
  missing_keys <- setdiff(required_keys, names(config))
  if (length(missing_keys) > 0) {
    stop(
      "Ensemble config for '", arg_list$outcome, "' is missing key(s): ",
      paste(missing_keys, collapse = ", "),
      call. = FALSE
    )
  }
})
# glue() fills the path template from filepaths.yml using variables from this
# (global) environment.
ids <- readRDS(glue(paths$modeling$val))
subscores <- readRDS(file.path(subscores_dir, "subscores_val_set.rds"))
# Both tables are filtered with the same positional row mask afterwards, so
# they must have the same number (and presumably order) of rows.
assert("Rows of IDs = rows of subscores", nrow(ids) == nrow(subscores))

# Subset Data ------------------------------------------------------------------
message("Subsetting data...")
# Row mask for the population named in the ensemble config. Without the
# default branch an unrecognised value makes switch() return NULL, which
# silently selects zero rows.
keep_pop <- switch(config$population,
  all = rep(TRUE, nrow(ids)),
  tested = ids$test_010_day == TRUE,
  untested = ids$test_010_day == FALSE,
  stop(
    "Unknown population '", config$population,
    "' in ensemble config; expected all, tested or untested",
    call. = FALSE
  )
)
keep_pop <- keep_pop & !ids$exclude_modeling
# A missing column yields a zero-length mask instead of an error; catch it.
assert("Population mask has one entry per row", length(keep_pop) == nrow(ids))
# which() drops NA mask entries (e.g. NA test_010_day), so data.frame, tibble
# and data.table inputs are all subset identically.
keep_rows <- which(keep_pop)
assert("Population filter keeps at least one row", length(keep_rows) > 0)
ids <- ids[keep_rows, ]
subscores <- subscores[keep_rows, ]

train_df <- ids %>%
  u$safe_left_join(subscores) %>%
  setDT()

# Map each config component to its subscore column: "z__" columns for the logit
# ensemble, "p__" (probability) columns for OLS. Only the first "gbm__" or
# "lasso__" occurrence in a component is prefixed; other names pass through.
# NOTE(review): z__ columns are presumably log-odds scores -- confirm.
score_prefix <- if (config$logit) "z" else "p"
covariates <- config$components %>%
  unlist() %>%
  str_replace("gbm__", str_c(score_prefix, "__gbm__")) %>%
  str_replace("lasso__", str_c(score_prefix, "__lasso__"))

# For restricted runs, covariate names carry a "__<restriction>" suffix.
if (arg_list$restriction != "all") {
  covariates <- str_c(covariates, "__", arg_list$restriction)
}

# Fitting ----------------------------------------------------------------------
message("Fitting models")
# The config's `constant` flag controls the intercept; intercept = FALSE is
# equivalent to appending "-1" to the term labels.
model_formula <- reformulate(
  termlabels = covariates,
  response = config$target,
  intercept = config$constant
)

# Logit ensemble -> binomial GLM; otherwise a Gaussian GLM (i.e. OLS).
model_family <- if (config$logit) "binomial" else "gaussian"

# model = FALSE / y = FALSE keep the model frame and response vector out of
# the fitted object; trim_glm() then drops further per-observation vectors.
model <- glm(model_formula,
  data = train_df, family = model_family,
  model = FALSE, y = FALSE,
  control = glm.control(maxit = 50, trace = TRUE)
)
model <- trim_glm(model)

# Save Data --------------------------------------------------------------------
message("Saving results...")
# One model file per outcome, placed in the split/restriction model directory.
model_path <- glue("ensemble__{arg_list$outcome}.rds") %>%
  file.path(models_dir, .)

write_rds(model, model_path)

# Done -------------------------------------------------------------------------
message("Done.")
